# src/vol5_k2m_cc/optics.py
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import numpy as np

@dataclass
class OpticsCfg:
    """Tunable parameters for ``compute_alpha_fields``.

    Every field has a default, so ``OpticsCfg()`` is a complete configuration.
    """

    # Index-contrast strengths λ to sweep; each instance gets its own fresh list
    # (the factory returns a new copy of the template on every construction).
    lambda_sweep: List[float] = field(default_factory=[0.2, 0.5, 1.0].copy)
    # Gaussian blur width, in pixels, applied to the normalized mask.
    index_blur_sigma: int = 3
    # Impact-parameter sweep: ``lensing_b_n`` evenly spaced values between
    # ``lensing_b_min`` and ``lensing_b_max`` (pixel offsets from the centre row).
    lensing_b_min: int = 12
    lensing_b_max: int = 64
    lensing_b_n: int = 32

# ---- simple separable Gaussian blur (no SciPy required) ----
def _gaussian_kernel_1d(sigma: float, truncate: float = 3.0) -> np.ndarray:
    if sigma <= 0:
        # degenerate kernel -> identity
        return np.array([1.0], dtype=np.float64)
    radius = max(1, int(truncate * sigma + 0.5))
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    k = np.exp(-0.5 * (x / float(sigma)) ** 2)
    k /= k.sum()
    return k

def _blur_separable(img: np.ndarray, sigma: float) -> np.ndarray:
    """Blur a 2-D image with a separable Gaussian of standard deviation ``sigma``.

    The image is reflect-padded by the kernel radius on every side, blurred with
    a 1-D pass along rows and then along columns, and cropped back to the input
    shape. For ``sigma <= 0`` the input is returned as float64 (without copying
    if it already has that dtype).
    """
    if sigma <= 0:
        return img.astype(np.float64, copy=False)

    kernel = _gaussian_kernel_1d(float(sigma))
    r = kernel.size // 2  # always >= 1, so the crop below is non-empty
    padded = np.pad(img.astype(np.float64, copy=False), ((r, r), (r, r)), mode="reflect")

    # Pass 1: convolve every row of the padded image.
    along_rows = np.stack([np.convolve(row, kernel, mode="same") for row in padded])
    # Pass 2: convolve every column of the row-blurred result.
    along_cols = np.stack(
        [np.convolve(col, kernel, mode="same") for col in along_rows.T], axis=1
    )

    # Drop the reflected margin to recover the original footprint.
    return along_cols[r:-r, r:-r]

# ---- small utilities ----
def _central_grad_y(arr: np.ndarray) -> np.ndarray:
    # central difference along y, shape -> (L-2, L)
    return 0.5 * (arr[2:, :] - arr[:-2, :])

def _linfit(x: np.ndarray, y: np.ndarray) -> Tuple[float, float, float]:
    x = x.astype(np.float64, copy=False)
    y = y.astype(np.float64, copy=False)
    xm = x.mean()
    ym = y.mean()
    dx = x - xm
    dy = y - ym
    denom = float(np.dot(dx, dx))
    if denom == 0.0:
        return 0.0, ym, float("nan")
    m = float(np.dot(dx, dy) / denom)
    c = float(ym - m * xm)
    yhat = m * x + c
    ss_res = float(np.sum((y - yhat) ** 2))
    ss_tot = float(np.sum((y - ym) ** 2))
    r2 = float("nan") if ss_tot == 0.0 else float(1.0 - ss_res / ss_tot)
    return m, c, r2

# ---- main computation ----
def compute_alpha_fields(mask: np.ndarray, cfg: OpticsCfg) -> Dict[str, float | str | None]:
    """
    Fit the small-angle deflection profile α(b) against 1/b for each λ in a sweep.

    Build an index map n = 1 + λ * blur(mask) and compute small-angle deflection
    for rays traveling along +x. For each impact parameter b, the deflection
    magnitude is approximated by integrating |∂n/∂y| along x on the two horizontal
    lines y = cy ± b (cy = 0.5 * (L - 1), the centre row), then averaging. Fit
    α(b) vs 1/b and return slopes/R² per λ, plus the best λ by R².

    Parameters
    ----------
    mask : np.ndarray
        Square ``(L, L)`` source map. A float64 copy is taken and scaled so its
        maximum is 1 (when that maximum is positive); the caller's array is
        never modified.
    cfg : OpticsCfg
        λ sweep, blur width and impact-parameter range. Impact parameters whose
        sampling rows would fall outside rows ``[1, L-2]`` (where a central
        difference exists) are discarded rather than clamped to the edge.

    Returns
    -------
    dict
        ``alpha_slope_lambda{λ}`` and ``alpha_r2_lambda{λ}`` for every λ (NaN when
        fewer than two usable impact parameters remain), ``alpha_best_lambda`` as
        ``"lambda_{λ}"`` (or None when no λ produced a finite R²), and
        ``alpha_best_r2`` (NaN in that case).
    """
    L = int(mask.shape[0])
    assert mask.shape == (L, L), f"mask shape must be (L,L); got {mask.shape}"
    # Always copy: normalizing in place would otherwise rescale the caller's
    # array whenever it is already float64.
    src = np.array(mask, dtype=np.float64)
    peak = src.max()
    if peak > 0:
        src /= peak

    base = _blur_separable(src, float(cfg.index_blur_sigma))

    # Impact parameters: distinct integer pixel offsets from the centre row.
    b_vals = np.linspace(int(cfg.lensing_b_min), int(cfg.lensing_b_max), int(cfg.lensing_b_n))
    b_idx = np.unique(np.round(b_vals).astype(int))
    cy = 0.5 * (L - 1)
    # Sampling rows above/below the centre (np.rint rounds half to even, like round()).
    row_hi = np.rint(cy + b_idx).astype(int)
    row_lo = np.rint(cy - b_idx).astype(int)
    # The central-difference gradient only exists for rows 1..L-2. Drop any b
    # whose rows leave that range: clamping them to the edge would assign the
    # same edge rows to several different b values and bias the 1/b fit.
    usable = (b_idx >= 1) & (row_lo >= 1) & (row_hi <= L - 2)
    b_idx, row_hi, row_lo = b_idx[usable], row_hi[usable], row_lo[usable]
    inv_b = 1.0 / b_idx.astype(np.float64)

    results: Dict[str, float | str | None] = {}
    best_r2 = -np.inf
    best_lambda = None

    for lam in cfg.lambda_sweep:
        n = 1.0 + float(lam) * base
        gy = _central_grad_y(n)  # row r of gy is the y-derivative at image row r + 1

        # Integrate |∂n/∂y| along x once per row, then gather the sampled rows
        # (shifting by one to map image rows onto gy rows).
        row_integral = np.abs(gy).sum(axis=1)
        alpha = 0.5 * (row_integral[row_hi - 1] + row_integral[row_lo - 1])

        if b_idx.size >= 2:
            m, _, r2 = _linfit(inv_b, alpha)
        else:
            # A line through fewer than two points is undefined.
            m, r2 = float("nan"), float("nan")

        results[f"alpha_slope_lambda{lam}"] = float(m)
        results[f"alpha_r2_lambda{lam}"] = float(r2)

        if np.isfinite(r2) and r2 > best_r2:
            best_r2 = r2
            best_lambda = lam

    results["alpha_best_lambda"] = f"lambda_{best_lambda}" if best_lambda is not None else None
    results["alpha_best_r2"] = float(best_r2) if np.isfinite(best_r2) else float("nan")
    return results
